# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from semilearn.core.algorithmbase import AlgorithmBase
from semilearn.core.utils import ALGORITHMS
from semilearn.algorithms.hooks import PseudoLabelingHook, FixedThresholdingHook
from semilearn.algorithms.utils import SSL_Argument, str2bool
import numpy as np
import torch.nn.functional as F
import copy
@ALGORITHMS.register('fixmatch')
class FixMatch(AlgorithmBase):

    """
        FixMatch algorithm (https://arxiv.org/abs/2001.07685).

        In addition to the FixMatch losses, every training step reports two
        cost-weighted confusion statistics (``ecc`` and ``emc``) comparing the
        pseudo labels with the predictions on the strongly augmented view.
        These statistics are defined for exactly 8 classes.

        Args:
            - args (`argparse`):
                algorithm arguments
            - net_builder (`callable`):
                network loading function
            - tb_log (`TBLog`):
                tensorboard logger
            - logger (`logging.Logger`):
                logger to use
            - T (`float`):
                Temperature for pseudo-label sharpening
            - p_cutoff(`float`):
                Confidence threshold for generating pseudo-labels
            - hard_label (`bool`, *optional*, default to `True`):
                If True, targets have [Batch size] shape with int values. If False, the target is vector
    """

    # Pairwise cost between the 8 classes (called "emotion entropy" in this
    # code base): 1 on the diagonal, larger values for class pairs that are
    # treated as further apart. Weights are the element-wise reciprocals.
    _EMOTION_COST = (
        (1, 2, 3, 4, 8, 7, 6, 5),
        (2, 1, 2, 3, 7, 8, 7, 6),
        (3, 2, 1, 2, 6, 7, 8, 7),
        (4, 3, 2, 1, 5, 6, 7, 8),
        (8, 7, 6, 5, 1, 2, 3, 4),
        (7, 8, 7, 6, 2, 1, 2, 3),
        (6, 7, 8, 7, 3, 2, 1, 2),
        (5, 6, 7, 8, 4, 3, 2, 1),
    )
    # Variant used only by compute_entropy: the same rows with the upper and
    # lower halves exchanged, which puts 8 on the diagonal.
    # NOTE(review): the purpose of this variant is not visible here - confirm.
    _EMOTION_COST_SWAPPED = _EMOTION_COST[4:] + _EMOTION_COST[:4]

    def __init__(self, args, net_builder, tb_log=None, logger=None):
        super().__init__(args, net_builder, tb_log, logger)
        # fixmatch specific arguments
        self.init(T=args.T, p_cutoff=args.p_cutoff, hard_label=args.hard_label)

    def init(self, T, p_cutoff, hard_label=True):
        """Store the FixMatch hyper-parameters on the instance."""
        self.T = T
        self.p_cutoff = p_cutoff
        self.use_hard_label = hard_label

    def set_hooks(self):
        """Register the pseudo-labeling and fixed-threshold masking hooks."""
        self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
        self.register_hook(FixedThresholdingHook(), "MaskingHook")
        super().set_hooks()

    @staticmethod
    def _emotion_weights(cost, device):
        """Build the reciprocal weight matrices for a square integer cost matrix.

        Returns:
            (distance, distance_noacc): float32 tensors on ``device``.
            ``distance`` is ``1 / cost``. ``distance_noacc`` is the reciprocal
            of the cost after shifting the off-diagonal entries so that the
            smallest off-diagonal cost becomes 1.
        """
        cost_arr = np.asarray(cost, dtype=np.int64)
        off_diag = ~np.eye(cost_arr.shape[0], dtype=bool)
        # For _EMOTION_COST the smallest off-diagonal cost is 2, so this is the
        # usual "subtract 1" shift. Deriving the shift from the data keeps
        # every shifted cost >= 1, which avoids 1/0 = inf weights (and the
        # resulting NaN sums) for matrices whose off-diagonal minimum is 1.
        # NOTE(review): confirm this normalisation is the intended behaviour
        # for _EMOTION_COST_SWAPPED.
        shift = cost_arr[off_diag].min() - 1
        cost_noacc = np.where(off_diag, cost_arr - shift, cost_arr)
        distance = torch.tensor(1.0 / cost_arr, dtype=torch.float32, device=device)
        distance_noacc = torch.tensor(1.0 / cost_noacc, dtype=torch.float32, device=device)
        return distance, distance_noacc

    @classmethod
    def _confusion_scores(cls, targets, logits, cost):
        """Compute cost-weighted confusion statistics of ``logits`` vs ``targets``.

        Args:
            targets: class indices of shape [batch], or per-class scores of
                shape [batch, num_classes] (reduced with argmax).
            logits: model outputs of shape [batch, num_classes].
            cost: square cost matrix; its size fixes ``num_classes``.

        Returns:
            (acc_EE, noacc_EE, num_samples, num_mismatch) where ``acc_EE`` is
            the sum of the confusion matrix weighted by ``1 / cost``,
            ``noacc_EE`` is the same sum restricted to off-diagonal cells and
            using the shifted weights, ``num_samples`` is an int and
            ``num_mismatch`` is a 0-d tensor counting predictions that differ
            from their target.

        Raises:
            ValueError: if ``logits`` is not [batch, num_classes].
        """
        num_classes = len(cost)
        if logits.dim() != 2 or logits.shape[1] != num_classes:
            raise ValueError(
                "expected logits of shape [batch, {}], got {}".format(
                    num_classes, tuple(logits.shape)))

        if targets.dim() > 1:
            # Soft targets (hard_label=False): fall back to their top class.
            targets = targets.argmax(dim=1)
        targets = targets.to(logits.device).long()
        _, preds = logits.max(dim=1)

        # Rows are targets, columns are predictions. Accumulating in a single
        # call replaces a per-sample Python loop over the batch.
        confusion = torch.zeros(num_classes, num_classes,
                                dtype=torch.float32, device=logits.device)
        confusion.index_put_((targets, preds),
                             torch.ones_like(targets, dtype=torch.float32),
                             accumulate=True)

        distance, distance_noacc = cls._emotion_weights(cost, logits.device)
        acc_EE = (confusion * distance).sum()

        mistakes_only = confusion.clone()
        mistakes_only.fill_diagonal_(0)
        noacc_EE = (mistakes_only * distance_noacc).sum()

        num_samples = len(targets)
        num_mismatch = num_samples - (preds == targets).sum()
        return acc_EE, noacc_EE, num_samples, num_mismatch

    @torch.no_grad()
    def eval_training(self, images, labels, epoch=0, tb=True):
        """Run the model on a labelled batch and score it against ``labels``.

        Switches ``self.model`` to eval mode and leaves it there; callers that
        continue training must call ``self.model.train()`` afterwards.
        ``epoch`` and ``tb`` are accepted for interface compatibility and are
        not used.

        Returns:
            (acc_EE, noacc_EE, num_samples, num_mismatch), see
            ``_confusion_scores``.
        """
        self.model.eval()
        images = images.cuda(self.args.gpu)
        labels = labels.cuda(self.args.gpu)
        outputs = self.model(images)
        return self._confusion_scores(labels, outputs['logits'], self._EMOTION_COST)

    @torch.no_grad()
    def pso_eval_training(self, pseudo_label, logits_x_ulb_s, epoch=0, tb=True):
        """Score strong-view logits against the pseudo labels.

        ``epoch`` and ``tb`` are accepted for interface compatibility and are
        not used.

        Returns:
            (acc_EE, noacc_EE, num_samples, num_mismatch), see
            ``_confusion_scores``.
        """
        return self._confusion_scores(pseudo_label, logits_x_ulb_s, self._EMOTION_COST)

    def compute_entropy(self, logits_x_ulb_s, pseudo_label=None):
        """Score strong-view logits against pseudo labels with the swapped cost matrix.

        Args:
            logits_x_ulb_s: logits of shape [batch, 8].
            pseudo_label: targets to compare against. Required; the default of
                ``None`` only exists to keep the one-argument signature.

        Returns:
            (acc_EE, noacc_EE, num_samples, num_mismatch), see
            ``_confusion_scores``.

        Raises:
            ValueError: if ``pseudo_label`` is not provided.
        """
        if pseudo_label is None:
            raise ValueError("compute_entropy requires pseudo_label to compare the logits against")
        # Mirror the other two scoring helpers, which are decorated with no_grad.
        with torch.no_grad():
            return self._confusion_scores(pseudo_label, logits_x_ulb_s,
                                          self._EMOTION_COST_SWAPPED)

    def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
        """Run one FixMatch training step.

        Computes the supervised loss on the labelled batch, the masked
        consistency loss between weak-view pseudo labels and strong-view
        logits, and the ``ecc`` / ``emc`` statistics of the strong-view
        predictions against the pseudo labels.

        Returns:
            (out_dict, log_dict) as built by ``process_out_dict`` and
            ``process_log_dict``.
        """
        num_lb = y_lb.shape[0]
        # Only the pseudo-label statistics are reported below, so no separate
        # evaluation forward pass over the labelled batch is run here.
        self.model.train()
        # inference and calculate sup/unsup losses
        with self.amp_cm():
            if self.use_cat:
                inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
                outputs = self.model(inputs)

                logits_x_lb = outputs['logits'][:num_lb]
                logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
                feats_x_lb = outputs['feat'][:num_lb]
                feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
            else:
                outs_x_lb = self.model(x_lb)
                logits_x_lb = outs_x_lb['logits']
                feats_x_lb = outs_x_lb['feat']
                outs_x_ulb_s = self.model(x_ulb_s)
                logits_x_ulb_s = outs_x_ulb_s['logits']
                feats_x_ulb_s = outs_x_ulb_s['feat']
                # The weak view only produces targets, so it needs no graph.
                with torch.no_grad():
                    outs_x_ulb_w = self.model(x_ulb_w)
                    logits_x_ulb_w = outs_x_ulb_w['logits']
                    feats_x_ulb_w = outs_x_ulb_w['feat']
            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}

            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')

            probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())

            # if distribution alignment hook is registered, call it
            # this is implemented for imbalanced algorithm - CReST
            if self.registered_hook("DistAlignHook"):
                probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w.detach())

            # compute mask
            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False)

            # generate unlabeled targets using pseudo label hook
            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
                                          logits=probs_x_ulb_w,
                                          use_hard_label=self.use_hard_label,
                                          T=self.T,
                                          softmax=False)

            unsup_loss = self.consistency_loss(logits_x_ulb_s,
                                               pseudo_label,
                                               'ce',
                                               mask=mask)

            total_loss = sup_loss + self.lambda_u * unsup_loss

        ecc_pso, emc_pso, num_pso, num_emc_pso = self.pso_eval_training(pseudo_label, logits_x_ulb_s)
        # Per-sample averages. An empty batch, or a batch without a single
        # mismatch, would otherwise divide by zero (error / NaN); report 0.0.
        num_mismatch = int(num_emc_pso)
        ecc = float(ecc_pso) / num_pso if num_pso > 0 else 0.0
        emc = float(emc_pso) / num_mismatch if num_mismatch > 0 else 0.0

        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict, ecc=float(ecc), emc=float(emc))
        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
                                         unsup_loss=unsup_loss.item(),
                                         total_loss=total_loss.item(),
                                         util_ratio=mask.float().mean().item())
        return out_dict, log_dict

    @staticmethod
    def get_argument():
        """Return the FixMatch-specific command line arguments."""
        return [
            SSL_Argument('--hard_label', str2bool, True),
            SSL_Argument('--T', float, 0.5),
            SSL_Argument('--p_cutoff', float, 0.95),
        ]
